import numpy as np
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d, gaussian_filter
import random
import math


def create_sequences_with_time(data, time_encoding, input_length, output_length):
    """
    Cut `data` into consecutive (input, output) pairs with matching time slices.

    Window `s` covers `data[s : s + input_length + output_length]`; the first
    `input_length` entries form the input and the rest form the output.

    Args:
        data: sliceable sequence supporting `len()`.
        time_encoding: sliceable along its first axis; each slice is passed to
            `torch.stack`, so slices are presumably tensors (verify at caller).
        input_length: number of steps in the input part of a window.
        output_length: number of steps in the output part of a window.

    Returns:
        sequences: list of `(input_seq, output_seq)` tuples, one per window.
        time_tensor: stacked time slices, one `[input+output, ...]` slice per
            window. `torch.stack` raises if there are no windows.
    """
    window = input_length + output_length
    n_windows = len(data) - window + 1

    sequences = [
        (data[s:s + input_length], data[s + input_length:s + window])
        for s in range(n_windows)
    ]
    time_slices = [time_encoding[s:s + window] for s in range(n_windows)]

    return sequences, torch.stack(time_slices)


def generate_spiral_dataset_backup2(n_trajectories=5000, total_steps=100, noise_std=0.01, seed=42, visualize=False):
    """
    Generate noisy 2D inward spirals that start near one shared point.

    A single phase is drawn for the whole dataset and the base radius is 1.0;
    each trajectory perturbs both slightly and draws its own radial decay and
    angular speed. Seeds NumPy's global RNG with `seed`.

    Args:
        n_trajectories: number of spirals.
        total_steps: samples per spiral, on a uniform grid over [0, 1].
        noise_std: standard deviation of the Gaussian noise added to x and y.
        seed: seed passed to `np.random.seed`.
        visualize: if True, plot up to five trajectories with matplotlib.

    Returns:
        dataset: ndarray of shape [n_trajectories, total_steps, 2].
        time: ndarray of shape [total_steps], `np.linspace(0, 1, total_steps)`.
    """
    np.random.seed(seed)
    time = np.linspace(0, 1, total_steps)
    dataset = np.zeros((n_trajectories, total_steps, 2))

    base_phase = np.random.uniform(0, 2 * np.pi)
    base_radius = 1.0

    for traj in range(n_trajectories):
        # The draw order below fixes the seeded output; do not reorder.
        radius0 = base_radius + np.random.uniform(-0.05, 0.05)
        decay = np.random.uniform(0.5, 1.0)
        ang_speed = np.random.uniform(2 * np.pi * 3, 2 * np.pi * 8)
        phase = base_phase + np.random.uniform(-0.1, 0.1)

        # The radius shrinks linearly and is floored at 1e-3.
        radius = np.clip(radius0 - decay * time, 1e-3, None)
        angle = ang_speed * time + phase

        dataset[traj, :, 0] = radius * np.cos(angle) + np.random.randn(total_steps) * noise_std
        dataset[traj, :, 1] = radius * np.sin(angle) + np.random.randn(total_steps) * noise_std

    if visualize:
        n_shown = min(5, n_trajectories)
        for traj in range(n_shown):
            plt.plot(dataset[traj, :, 0], dataset[traj, :, 1], label=f'Traj {traj}')
        plt.xlabel("x")
        plt.ylabel("y")
        plt.title("Sample Spiral Trajectories (with nearby initial points)")
        plt.axis("equal")
        plt.legend()
        plt.grid(True)
        plt.show()

    return dataset, time


def generate_load_dataset(n_trajectories=100, total_steps=144, csv_path='Flores_load_2008.csv'):
    """
    Load a single-column load series from CSV and cut it into trajectories.

    The first line of the file is skipped and the remaining values are read
    into one column. The whole series is min-max scaled to [-1, 1], then the
    first `n_trajectories * total_steps` samples are reshaped into
    non-overlapping trajectories. The min and max of the scaled series are
    printed. A constant series makes the scaling divide by zero and yields NaNs.

    Args:
        n_trajectories: number of trajectories to return.
        total_steps: samples per trajectory.
        csv_path: path of the CSV file; defaults to the originally hard-coded
            'Flores_load_2008.csv'.

    Returns:
        dataset: ndarray of shape [n_trajectories, total_steps, 1].
        time: ndarray of shape [total_steps], `np.linspace(0, 1, total_steps)`.

    Raises:
        ValueError: if the file holds fewer than
            `n_trajectories * total_steps` samples.
    """
    df = pd.read_csv(csv_path, skiprows=1, header=None, names=['value'])
    raw_data = np.array(df['value'])

    # Fail with a readable message instead of an opaque reshape error.
    required = n_trajectories * total_steps
    if raw_data.size < required:
        raise ValueError(
            f"{csv_path} holds {raw_data.size} samples but "
            f"{n_trajectories} x {total_steps} = {required} are required"
        )

    min_val = raw_data.min()
    max_val = raw_data.max()
    normalize_data = 2 * (raw_data - min_val) / (max_val - min_val) - 1
    print(normalize_data.min(), normalize_data.max())

    dataset = normalize_data[:required].reshape(n_trajectories, total_steps, 1)
    time = np.linspace(0, 1, total_steps)

    return dataset, time

def generate_spiral_dataset_backup(n_trajectories=5000, total_steps=100, noise_std=0.01, seed=42, visualize=False):
    """
    Generate noisy 2D inward spirals with independent random parameters.

    Each trajectory draws its own initial radius, linear radial decay, angular
    speed and initial phase. Seeds NumPy's global RNG with `seed`.

    Args:
        n_trajectories: number of spirals.
        total_steps: samples per spiral, on a uniform grid over [0, 1].
        noise_std: standard deviation of the Gaussian noise added to x and y.
        seed: seed passed to `np.random.seed`.
        visualize: if True, plot up to five trajectories with matplotlib.

    Returns:
        dataset: ndarray of shape [n_trajectories, total_steps, 2].
        time: ndarray of shape [total_steps], `np.linspace(0, 1, total_steps)`.
    """
    np.random.seed(seed)
    dataset = np.zeros((n_trajectories, total_steps, 2))
    time = np.linspace(0, 1, total_steps)

    for i in range(n_trajectories):
        # The draw order fixes the seeded output; do not reorder.
        a = np.random.uniform(0.8, 1.2)
        b = np.random.uniform(0.5, 1.0)
        omega = np.random.uniform(2 * np.pi * 3, 2 * np.pi * 8)
        theta0 = np.random.uniform(0, 2 * np.pi)

        # The radius shrinks linearly and is floored at 1e-3.
        r = np.clip(a - b * time, a_min=1e-3, a_max=None)
        theta = omega * time + theta0

        x = r * np.cos(theta)
        y = r * np.sin(theta)

        x += np.random.randn(total_steps) * noise_std
        y += np.random.randn(total_steps) * noise_std

        dataset[i] = np.stack([x, y], axis=-1)

    if visualize:
        # Cap at the number of trajectories: a plain range(5) raised
        # IndexError whenever fewer than five were generated.
        for i in range(min(5, n_trajectories)):
            plt.plot(dataset[i, :, 0], dataset[i, :, 1], label=f'Trajectory {i}')
        plt.xlabel("x")
        plt.ylabel("y")
        plt.title("Sample Spiral Trajectories")
        plt.axis("equal")
        plt.legend()
        plt.grid(True)
        plt.show()

    return dataset, time



def apply_random_mask(data, mask_ratio=0.3, seed=42):
    """
    Zero out a random subset of time steps in every sequence.

    Seeds both NumPy's and PyTorch's global RNGs with `seed`, then hides
    `int(mask_ratio * T)` distinct time steps per sequence.

    Args:
        data: array-like of shape [N, T, D] (must expose `.shape`).
        mask_ratio: fraction of time steps hidden in each sequence.
        seed: seed for `np.random.seed` and `torch.manual_seed`.

    Returns:
        masked_data: float32 tensor [N, T, D], zero at hidden time steps.
        mask: float32 tensor [N, T], 1 = observed, 0 = hidden.
    """
    n_seq, n_steps, _ = data.shape
    np.random.seed(seed)
    torch.manual_seed(seed)

    n_hidden = int(mask_ratio * n_steps)
    mask = torch.ones((n_seq, n_steps), dtype=torch.float32)
    for row in range(n_seq):
        hidden = np.random.choice(n_steps, n_hidden, replace=False)
        mask[row, hidden] = 0.0

    # Broadcast the per-step mask over the feature axis.
    values = torch.tensor(data, dtype=torch.float32)
    masked_data = values * mask[:, :, None]

    return masked_data, mask
def create_windows_from_masked_data(masked_data, original_data, mask, time_all, input_length, output_length, stride=1):
    """
    Slice already-masked trajectories into sliding input/output windows.

    Windows are enumerated trajectory by trajectory, with start positions
    `0, stride, 2*stride, ...` for as long as the input plus output span fits.

    Args:
        masked_data: tensor [N, T, D], the masked inputs.
        original_data: tensor [N, T, D], the unmasked ground truth.
        mask: tensor [N, T], the observation mask.
        time_all: [T] time values, converted to a float32 tensor.
        input_length: steps in the input part of a window.
        output_length: steps in the output part of a window.
        stride: step between consecutive window starts.

    Returns:
        x_win:       [B, input_length, D]  slices of `masked_data`
        y_inter_win: [B, input_length, D]  slices of `original_data`
        y_extra_win: [B, output_length, D] slices of `original_data`
        t_inter_win: [B, input_length]
        t_extra_win: [B, output_length]
        mask_win:    [B, input_length]
    """
    n_traj, n_steps, _ = masked_data.shape
    times = torch.tensor(time_all, dtype=torch.float32)
    span = input_length + output_length

    # (trajectory, window start) pairs in trajectory-major order.
    anchors = [
        (traj, begin)
        for traj in range(n_traj)
        for begin in range(0, n_steps - span + 1, stride)
    ]

    def gather(take):
        return torch.stack([take(traj, begin) for traj, begin in anchors], dim=0)

    return (
        gather(lambda k, s: masked_data[k, s:s + input_length, :]),
        gather(lambda k, s: original_data[k, s:s + input_length, :]),
        gather(lambda k, s: original_data[k, s + input_length:s + span, :]),
        gather(lambda k, s: times[s:s + input_length]),
        gather(lambda k, s: times[s + input_length:s + span]),
        gather(lambda k, s: mask[k, s:s + input_length]),
    )
def create_moving_window_batches_v2(x_all, y_interp_all, y_extrap_all, t_normalized, mask_all,
                                    input_length, output_length, stride=1):
    """
    Create sliding window batches with aligned interpolation and extrapolation targets.

    `y_extrap_all` only covers the last `T_extrap` steps of the timeline, so a
    window is kept only if its output span lies entirely inside that region.

    Args:
        x_all: tensor [N, T, D], model inputs.
        y_interp_all: tensor [N, T, D], interpolation targets.
        y_extrap_all: tensor [N, T_extrap, D], targets for the trailing
            `T_extrap` steps of the timeline.
        t_normalized: [T] time values. A tensor is used as-is; any other
            array-like is converted to a float32 tensor.
        mask_all: tensor [N, T], the observation mask.
        input_length: encoder length.
        output_length: number of future steps after the input window.
        stride: step between consecutive window starts.

    Returns:
        x_win:       [B, input_length, D]
        y_inter_win: [B, input_length, D]
        y_extra_win: [B, output_length, D]
        t_inter_win: [B, input_length]
        t_extra_win: [B, output_length]
        mask_win:    [B, input_length]

    Raises:
        RuntimeError: from `torch.cat` if no window qualifies.
    """
    # Accept NumPy arrays or lists for time, like the sibling windowing helpers.
    if not torch.is_tensor(t_normalized):
        t_normalized = torch.tensor(t_normalized, dtype=torch.float32)

    x_windows, y_inter_windows, y_extra_windows = [], [], []
    t_inter_windows, t_extra_windows, mask_windows = [], [], []

    N, T, D = x_all.shape
    T_extrap = y_extrap_all.shape[1]
    extrap_base_offset = T - T_extrap  # index in the full timeline where y_extrap_all starts

    for i in range(N):
        for start in range(0, T - input_length - output_length + 1, stride):
            end_input = start + input_length
            end_output = end_input + output_length

            # Translate the output span into y_extrap_all's own index space.
            extrap_start = end_input - extrap_base_offset
            extrap_end = end_output - extrap_base_offset

            if extrap_start < 0 or extrap_end > T_extrap:
                continue  # output span is not fully covered by y_extrap_all

            x_windows.append(x_all[i, start:end_input, :].unsqueeze(0))
            y_inter_windows.append(y_interp_all[i, start:end_input, :].unsqueeze(0))
            y_extra_windows.append(y_extrap_all[i, extrap_start:extrap_end, :].unsqueeze(0))
            mask_windows.append(mask_all[i, start:end_input].unsqueeze(0))
            t_inter_windows.append(t_normalized[start:end_input].unsqueeze(0))
            t_extra_windows.append(t_normalized[end_input:end_output].unsqueeze(0))

    return (
        torch.cat(x_windows, dim=0),
        torch.cat(y_inter_windows, dim=0),
        torch.cat(y_extra_windows, dim=0),
        torch.cat(t_inter_windows, dim=0),
        torch.cat(t_extra_windows, dim=0),
        torch.cat(mask_windows, dim=0)
    )

def generate_masked_sliding_windows(
    data, time_all, input_length, output_length, mask_ratio=0.3, stride=1, seed=42
):
    """
    Draw one mask per trajectory, then slice trajectories into sliding windows.

    Seeds NumPy's and PyTorch's global RNGs with `seed`. For each trajectory,
    `int(mask_ratio * T)` distinct time steps are marked hidden. The mask is
    returned alongside the windows but is NOT applied to the inputs: `x_win`
    holds the unmasked values.

    Args:
        data: [N, T, D] full trajectories (array-like accepted by `torch.tensor`).
        time_all: [T] global time values.
        input_length: steps in the input part of a window.
        output_length: steps in the output part of a window.
        mask_ratio: fraction of each trajectory's time steps marked hidden.
        stride: step between consecutive window starts.
        seed: seed for `np.random.seed` and `torch.manual_seed`.

    Returns:
        x_win:       [B, input_length, D]  unmasked input values
        y_inter_win: [B, input_length, D]  same values as x_win (separate tensor)
        y_extra_win: [B, output_length, D] values following each input window
        mask_win:    [B, input_length]     1 = observed, 0 = hidden
        t_inter_win: [B, input_length]     times of the input steps
        t_extra_win: [B, output_length]    times of the output steps
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    n_traj, n_steps, _ = data.shape
    times = torch.tensor(time_all, dtype=torch.float32)
    n_hidden = int(mask_ratio * n_steps)
    span = input_length + output_length

    inputs, targets_in, targets_out = [], [], []
    masks, times_in, times_out = [], [], []

    for traj in range(n_traj):
        # One mask per trajectory, shared by all of its windows.
        traj_mask = torch.ones(n_steps, dtype=torch.float32)
        if n_hidden > 0:
            hidden = np.random.choice(n_steps, n_hidden, replace=False)
            traj_mask[hidden] = 0.0

        for begin in range(0, n_steps - span + 1, stride):
            split = begin + input_length
            stop = begin + span

            observed = torch.tensor(data[traj, begin:split, :], dtype=torch.float32)
            future = torch.tensor(data[traj, split:stop, :], dtype=torch.float32)

            inputs.append(observed)
            targets_in.append(observed)
            targets_out.append(future)
            masks.append(traj_mask[begin:split])
            times_in.append(times[begin:split])
            times_out.append(times[split:stop])

    # torch.stack copies, so the returned x and y_inter tensors do not alias.
    return (
        torch.stack(inputs, dim=0),
        torch.stack(targets_in, dim=0),
        torch.stack(targets_out, dim=0),
        torch.stack(masks, dim=0),
        torch.stack(times_in, dim=0),
        torch.stack(times_out, dim=0),
    )

def create_tagged_windows(data, time_all, input_length, output_length, mask_ratio=0.3, stride=1, seed=42):
    """
    Build sliding windows with per-trajectory masks, tagged by start index.

    Seeds NumPy's and PyTorch's global RNGs with `seed`. Each trajectory gets
    one mask with `int(mask_ratio * T)` hidden steps; the mask is stored next
    to the window but is not applied to `x`.

    Args:
        data: [N, T, D] trajectories (array-like accepted by `torch.tensor`).
        time_all: [T] global time values.
        input_length: steps in the input part of a window.
        output_length: steps in the output part of a window.
        mask_ratio: fraction of each trajectory's time steps marked hidden.
        stride: step between consecutive window starts.
        seed: seed for `np.random.seed` and `torch.manual_seed`.

    Returns:
        List of dicts, trajectory-major, each with keys 'x', 'y_inter',
        'y_extra', 'mask', 't_inter', 't_extra' (tensors) and 'start_idx' (int).
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    n_traj, n_steps, _ = data.shape
    times = torch.tensor(time_all, dtype=torch.float32)
    n_hidden = int(mask_ratio * n_steps)
    span = input_length + output_length
    tagged = []

    for traj in range(n_traj):
        traj_mask = torch.ones(n_steps, dtype=torch.float32)
        if n_hidden > 0:
            hidden = np.random.choice(n_steps, n_hidden, replace=False)
            traj_mask[hidden] = 0.0

        for begin in range(0, n_steps - span + 1, stride):
            split = begin + input_length
            stop = begin + span

            observed = torch.tensor(data[traj, begin:split], dtype=torch.float32)
            window = dict(
                x=observed,
                y_inter=observed.clone(),
                y_extra=torch.tensor(data[traj, split:stop], dtype=torch.float32),
                mask=traj_mask[begin:split],
                t_inter=times[begin:split],
                t_extra=times[split:stop],
                start_idx=begin,
            )
            tagged.append(window)

    return tagged

from collections import defaultdict

def bucket_windows_by_start(windows):
    """
    Group window dicts by their 'start_idx' value.

    Args:
        windows: iterable of dicts, each containing a 'start_idx' key.

    Returns:
        defaultdict(list) mapping each start index to its windows, in the
        order they were encountered.
    """
    grouped = defaultdict(list)
    for window in windows:
        key = window['start_idx']
        grouped[key].append(window)
    return grouped


from torch.utils.data import TensorDataset, DataLoader

def build_bucket_dataloaders(buckets, batch_size):
    """
    Wrap every bucket of same-start windows in its own shuffled DataLoader.

    Args:
        buckets: mapping start index -> non-empty list of window dicts with
            tensor entries 'x', 'y_inter', 'y_extra', 'mask', 't_inter', 't_extra'.
        batch_size: batch size passed to each DataLoader.

    Returns:
        dict mapping start index -> {'loader', 't_inter', 't_extra'}. The
        loader yields (x, y_inter, y_extra, mask) batches; the time vectors
        are taken from the first window of the bucket.
    """
    stacked_fields = ('x', 'y_inter', 'y_extra', 'mask')
    dataloaders = {}

    for start_idx in buckets:
        members = buckets[start_idx]
        tensors = [
            torch.stack([window[name] for window in members])
            for name in stacked_fields
        ]
        first = members[0]
        dataloaders[start_idx] = {
            'loader': DataLoader(TensorDataset(*tensors), batch_size=batch_size, shuffle=True),
            't_inter': first['t_inter'],
            't_extra': first['t_extra'],
        }

    return dataloaders

def prepare_masked_full_test_set(data_test, time_all, mask_ratio=0.3, seed=999):
    """
    Mask full-length test trajectories with a seeded random mask.

    Seeds NumPy's and PyTorch's global RNGs with `seed`, hides
    `int(mask_ratio * T)` distinct time steps per trajectory and zeroes the
    input at those steps.

    Args:
        data_test: [N, T, D] test trajectories (must expose `.shape`).
        time_all: [T] global time values.
        mask_ratio: fraction of time steps hidden per trajectory.
        seed: seed for `np.random.seed` and `torch.manual_seed`.

    Returns:
        x_test: float32 tensor [N, T, D], zero at hidden steps.
        y_test: float32 tensor [N, T, D], the unmasked data.
        mask_test: float32 tensor [N, T], 1 = observed, 0 = hidden.
        t_test: float32 tensor [T].
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    n_traj, n_steps, _ = data_test.shape
    truth = torch.tensor(data_test, dtype=torch.float32)
    times = torch.tensor(time_all, dtype=torch.float32)

    n_hidden = int(mask_ratio * n_steps)
    observed = torch.ones((n_traj, n_steps), dtype=torch.float32)
    for row in range(n_traj):
        hidden = np.random.choice(n_steps, n_hidden, replace=False)
        observed[row, hidden] = 0.0

    # Broadcast the per-step mask over the feature axis.
    masked_input = truth * observed[:, :, None]

    return masked_input, truth, observed, times

def prepare_full_trajectory_test_data(data, time_all, mask_ratio=0.3, seed=42):
    """
    Prepare full-length (unwindowed) test tensors plus a seeded random mask.

    Seeds NumPy's and PyTorch's global RNGs with `seed` and marks
    `int(mask_ratio * T)` distinct steps per trajectory as hidden. The mask is
    only returned, not applied: the first two return values are the same
    unmasked tensor object.

    Args:
        data: [N, T, D] trajectories (must expose `.shape`).
        time_all: [T] time values.
        mask_ratio: fraction of time steps marked hidden per trajectory.
        seed: seed for `np.random.seed` and `torch.manual_seed`.

    Returns:
        x_test: float32 tensor [N, T, D], the unmasked data (same object as y_test).
        y_test: float32 tensor [N, T, D], the unmasked data.
        mask:   float32 tensor [N, T], 1 = observed, 0 = hidden.
        t_test: float32 tensor [T].
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    n_traj, n_steps, _ = data.shape
    times = torch.tensor(time_all, dtype=torch.float32)
    truth = torch.tensor(data, dtype=torch.float32)

    n_hidden = int(mask_ratio * n_steps)
    observed = torch.ones((n_traj, n_steps), dtype=torch.float32)
    for row in range(n_traj):
        hidden = np.random.choice(n_steps, n_hidden, replace=False)
        observed[row, hidden] = 0.0

    # The input is intentionally left unmasked; callers receive the mask separately.
    return truth, truth, observed, times


def reconstruct_full_test_trajectories(
    model,
    data_test,
    time_all,
    input_length,
    stride=1,
    mask_ratio=0.3,
    visual_traj_index=0,
    device='cpu'
):
    """
    Reconstruct whole trajectories by averaging overlapping window predictions.

    Each trajectory gets one random mask (`int(mask_ratio * T)` hidden steps,
    drawn with `torch.randperm`, no seeding here). The model is run on every
    window of length `input_length` and predictions at steps covered by
    several windows are averaged. Steps not covered by any window stay zero in
    both the prediction and the returned ground truth.

    Args:
        model: object with `.eval()`, called as `model(x, t, mask)` with
            x [1, L, D], t [L], mask [1, L]; must return at least three values
            (prediction, latent sequence, gate sequence).
        data_test: [N, T, D] trajectories (must expose `.shape`).
        time_all: [T] time values.
        input_length: window length L.
        stride: step between consecutive window starts.
        mask_ratio: fraction of time steps hidden per trajectory.
        visual_traj_index: unused.
        device: device on which tensors are created.

    Returns:
        pred_full: [N, T, D] averaged predictions (on CPU).
        gt_full:   [N, T, D] data at covered steps, zero elsewhere (on CPU).
        mask_full: [N, T] 1 = observed, 0 = hidden (on CPU).
        avg_gate:  mean of all gate outputs over their first two axes, or
            None if no window was processed.
    """
    model.eval()
    n_traj, n_steps, n_dims = data_test.shape
    times = torch.tensor(time_all, dtype=torch.float32, device=device)
    series = torch.tensor(data_test, dtype=torch.float32, device=device)

    # Running sum of predictions and how many windows touched each entry.
    pred_sum = torch.zeros(n_traj, n_steps, n_dims, device=device)
    hits = torch.zeros(n_traj, n_steps, n_dims, device=device)
    mask_full = torch.ones(n_traj, n_steps, dtype=torch.float32, device=device)
    gt_full = torch.zeros(n_traj, n_steps, n_dims, device=device)

    gates = []
    n_hidden = int(mask_ratio * n_steps)

    with torch.no_grad():
        for traj in range(n_traj):
            # One mask per trajectory, reused by all of its windows.
            traj_mask = torch.ones(n_steps, dtype=torch.float32, device=device)
            hidden = torch.randperm(n_steps)[:n_hidden]
            traj_mask[hidden] = 0.0
            mask_full[traj] = traj_mask

            for begin in range(0, n_steps - input_length + 1, stride):
                stop = begin + input_length

                pred_window, _, gate_seq, *_ = model(
                    series[traj, begin:stop].unsqueeze(0),
                    times[begin:stop],
                    traj_mask[begin:stop].unsqueeze(0),
                )

                pred_sum[traj, begin:stop] += pred_window[0]
                hits[traj, begin:stop] += 1
                gt_full[traj, begin:stop] = series[traj, begin:stop]
                gates.append(gate_seq.cpu())

    # Avoid dividing by zero at steps no window covered.
    hits[hits == 0] = 1.0
    pred_full = pred_sum / hits

    avg_gate = torch.cat(gates, dim=0).mean(dim=(0, 1)) if gates else None

    return pred_full.cpu(), gt_full.cpu(), mask_full.cpu(), avg_gate


def extract_latent_flow_nonoverlap(
    model,
    data_test,
    time_all,
    input_length,
    visual_traj_indices=(0,),
    mask_ratio=0.3,
    device='cpu'
):
    """
    Extract latent z-sequences for specified trajectories using non-overlapping windows.

    Every window gets its own random mask with `int(mask_ratio * input_length)`
    hidden steps, drawn with `torch.randperm` (no seeding here). Trailing
    steps that do not fill a whole window are ignored.

    Args:
        model: object with `.eval()`, called as `model(x, t, mask)` with
            x [1, L, D], t [L], mask [1, L]; must return at least two values,
            the second being the latent sequence indexed as `z_seq[0]`.
        data_test: [N, T, D] trajectories (must expose `.shape`).
        time_all: [T] time values.
        input_length: window length L (also the step between windows).
        visual_traj_indices: iterable of trajectory indices to process, in
            order. Defaults to the first trajectory only.
        mask_ratio: fraction of each window's steps hidden.
        device: device on which tensors are created.

    Returns:
        z_continuous: latent sequences of all processed windows concatenated
            along the first axis (on CPU), or None if no window was processed.
        t_continuous: matching time values (on CPU), or None.
    """
    model.eval()
    N, T, D = data_test.shape
    t_tensor = torch.tensor(time_all, dtype=torch.float32, device=device)
    data_tensor = torch.tensor(data_test, dtype=torch.float32, device=device)

    all_z_list = []
    all_t_list = []

    num_mask = int(mask_ratio * input_length)

    with torch.no_grad():
        for i in visual_traj_indices:
            for start in range(0, T - input_length + 1, input_length):  # non-overlapping
                end = start + input_length
                x_window = data_tensor[i, start:end].unsqueeze(0)  # [1, L, D]

                # Fresh mask for every window.
                mask = torch.ones(input_length, dtype=torch.float32, device=device)
                if num_mask > 0:
                    idx = torch.randperm(input_length)[:num_mask]
                    mask[idx] = 0.0
                mask_window = mask.unsqueeze(0)

                t_window = t_tensor[start:end]

                _, z_seq, *_ = model(x_window, t_window, mask_window)
                all_z_list.append(z_seq[0].cpu())
                all_t_list.append(t_window.cpu())

    z_continuous = torch.cat(all_z_list, dim=0) if all_z_list else None
    t_continuous = torch.cat(all_t_list, dim=0) if all_t_list else None

    return z_continuous, t_continuous

def generate_spiral_dataset_backup3(n_trajectories=5000, total_steps=100, noise_std=0.01, seed=42, visualize=False):
    """
    Generate noisy 2D inward spirals that are small perturbations of one spiral.

    A shared phase is drawn once; the base radius (1.0), base radial decay
    (0.7) and base angular speed (2*pi*5) are fixed. Each trajectory perturbs
    all four slightly. Seeds NumPy's global RNG with `seed`.

    Args:
        n_trajectories: number of spirals.
        total_steps: samples per spiral, on a uniform grid over [0, 1].
        noise_std: standard deviation of the Gaussian noise added to x and y.
        seed: seed passed to `np.random.seed`.
        visualize: if True, plot up to five trajectories with matplotlib.

    Returns:
        dataset: ndarray of shape [n_trajectories, total_steps, 2].
        time: ndarray of shape [total_steps], `np.linspace(0, 1, total_steps)`.
    """
    np.random.seed(seed)
    time = np.linspace(0, 1, total_steps)
    dataset = np.zeros((n_trajectories, total_steps, 2))

    base_phase = np.random.uniform(0, 2 * np.pi)
    base_radius = 1.0
    base_decay = 0.7
    base_speed = 2 * np.pi * 5

    for traj in range(n_trajectories):
        # The draw order fixes the seeded output; do not reorder.
        radius0 = base_radius + np.random.uniform(-0.05, 0.05)
        decay = base_decay + np.random.uniform(-0.05, 0.05)
        ang_speed = base_speed + np.random.uniform(-2.0, 2.0)
        phase = base_phase + np.random.uniform(-0.1, 0.1)

        # The radius shrinks linearly and is floored at 1e-3.
        radius = np.clip(radius0 - decay * time, 1e-3, None)
        angle = ang_speed * time + phase

        noisy_x = radius * np.cos(angle) + np.random.randn(total_steps) * noise_std
        noisy_y = radius * np.sin(angle) + np.random.randn(total_steps) * noise_std

        dataset[traj] = np.stack((noisy_x, noisy_y), axis=1)

    if visualize:
        n_shown = min(5, n_trajectories)
        for traj in range(n_shown):
            plt.plot(dataset[traj, :, 0], dataset[traj, :, 1], label=f'Traj {traj}')
        plt.xlabel("x")
        plt.ylabel("y")
        plt.title("Sample Spiral Trajectories (Small Initial + Rotation + Scaling Variations)")
        plt.axis("equal")
        plt.legend()
        plt.grid(True)
        plt.show()

    return dataset, time

def generate_spiral_dataset(n_trajectories=5000, total_steps=100, noise_std=0.01, seed=42, visualize=False):
    """
    Generate noisy 2D spirals whose radius shrinks towards the centre.

    The radius follows `a * (1 - t)**b`, floored at 1e-3, with per-trajectory
    random `a`, `b`, angular speed and initial phase. Seeds NumPy's global RNG
    with `seed`.

    Args:
        n_trajectories: number of spirals.
        total_steps: samples per spiral, on a uniform grid over [0, 1].
        noise_std: standard deviation of the Gaussian noise added to x and y.
        seed: seed passed to `np.random.seed`.
        visualize: if True, plot up to five trajectories with matplotlib.

    Returns:
        dataset: ndarray of shape [n_trajectories, total_steps, 2].
        time: ndarray of shape [total_steps], `np.linspace(0, 1, total_steps)`.
    """
    np.random.seed(seed)
    time = np.linspace(0, 1, total_steps)
    dataset = np.zeros((n_trajectories, total_steps, 2))

    for traj in range(n_trajectories):
        # The draw order fixes the seeded output; do not reorder.
        radius0 = np.random.uniform(0.9, 1.1)
        shrink_power = np.random.uniform(1.5, 2.0)
        ang_speed = np.random.uniform(2 * np.pi * 3, 2 * np.pi * 5)
        phase = np.random.uniform(0, 2 * np.pi)

        # Power-law decay of the radius, floored at 1e-3.
        radius = np.clip(radius0 * (1 - time)**shrink_power, 1e-3, None)
        angle = ang_speed * time + phase

        noisy_x = radius * np.cos(angle) + np.random.randn(total_steps) * noise_std
        noisy_y = radius * np.sin(angle) + np.random.randn(total_steps) * noise_std

        dataset[traj, :, 0] = noisy_x
        dataset[traj, :, 1] = noisy_y

    if visualize:
        n_shown = min(5, n_trajectories)
        for traj in range(n_shown):
            plt.plot(dataset[traj, :, 0], dataset[traj, :, 1], label=f'Trajectory {traj}')
        plt.xlabel("x")
        plt.ylabel("y")
        plt.title("Sample Shrinking Spiral Trajectories")
        plt.axis("equal")
        plt.legend()
        plt.grid(True)
        plt.show()

    return dataset, time

def prepare_interpolation_extrapolation_data(data, num_interp_steps=100, mask_ratio=0.3):
    """
    Split trajectories into an interpolation head and an extrapolation tail.

    A random observation mask is drawn for the head with NumPy's global RNG
    (not seeded here). The mask is only returned; the returned input is NOT
    masked.

    Args:
        data (ndarray): shape [N, T, D], N trajectories of T steps and D dimensions.
        num_interp_steps (int): number of leading steps that form the head.
        mask_ratio (float): fraction of head steps marked hidden per trajectory.

    Returns:
        x_interp:      [N, num_interp_steps, D] head values (unmasked).
        y_interp_true: [N, num_interp_steps, D] independent copy of x_interp.
        y_extrap_true: [N, T - num_interp_steps, D] tail values.
        t_normalized:  [T] float32 tensor, `linspace(0, 1, T)`.
        mask:          [N, num_interp_steps] 1 = observed, 0 = hidden.
    """
    n_traj, n_steps, _ = data.shape
    assert num_interp_steps <= n_steps, "Interpolation steps exceed total time steps"

    head = torch.tensor(data[:, :num_interp_steps, :], dtype=torch.float32)
    tail = torch.tensor(data[:, num_interp_steps:, :], dtype=torch.float32)

    # Time grid shared by all trajectories.
    t_normalized = torch.tensor(np.linspace(0, 1, n_steps), dtype=torch.float32)

    n_hidden = int(mask_ratio * num_interp_steps)
    mask = torch.ones((n_traj, num_interp_steps), dtype=torch.float32)
    for row in range(n_traj):
        hidden = np.random.choice(num_interp_steps, n_hidden, replace=False)
        mask[row, hidden] = 0.0

    return head, head.clone(), tail, t_normalized, mask

def visualize_masked_trajectories(x_interp, y_interp_true, mask, num_trajs_to_plot=2):
    """
    Plot ground-truth 2D curves with their observed points for random trajectories.

    Trajectories are picked without replacement using NumPy's global RNG and
    drawn side by side in one matplotlib figure.

    Args:
        x_interp: [N, num_interp_steps, D] input values; points where the
            mask equals 1 are scattered.
        y_interp_true: [N, num_interp_steps, D] ground-truth curves.
        mask: [N, num_interp_steps] binary mask (1 = observed, 0 = masked).
        num_trajs_to_plot: number of trajectories (and subplots) to draw.

    Raises:
        AssertionError: if D is not 2.
    """
    n_traj, _, n_dims = x_interp.shape
    assert n_dims == 2, "This visualization assumes 2D trajectories."

    chosen = np.random.choice(n_traj, num_trajs_to_plot, replace=False)

    plt.figure(figsize=(6 * num_trajs_to_plot, 5))
    for panel, traj in enumerate(chosen, start=1):
        truth = y_interp_true[traj]
        # Keep only the points flagged as observed.
        seen = x_interp[traj][mask[traj] == 1]

        plt.subplot(1, num_trajs_to_plot, panel)
        plt.plot(truth[:, 0], truth[:, 1], label='Ground Truth Trajectory', color='blue')
        plt.scatter(seen[:, 0], seen[:, 1], label='Observed Points', color='red', marker='o')
        plt.title(f"Trajectory {traj}")
        plt.xlabel("x1")
        plt.ylabel("x2")
        plt.legend()
        plt.axis('equal')
        plt.grid(True)

    plt.tight_layout()
    plt.show()

def generate_inward_spiral_with_rotation_scaling(
    total_steps=1000,
    a=1.0,
    b=0.8,
    omega=2 * np.pi * 6,  # higher = more rotations
    noise_std=0.0,
    visualize=True
):
    """
    Generate a single inward spiral with linearly shrinking radius.

    The radius is `a - b * t`, floored at 1e-3, and the angle is `omega * t`
    for t on a uniform grid over [0, 1].

    Args:
        total_steps (int): number of time steps.
        a (float): initial radius.
        b (float): radial decay coefficient.
        omega (float): angular frequency.
        noise_std (float): if positive, std of Gaussian noise added to x and y
            (drawn from NumPy's global RNG).
        visualize (bool): if True, plot the spiral and its components over time.

    Returns:
        data (ndarray): [total_steps, 2] spiral points (x, y).
        time (ndarray): [total_steps] time grid in [0, 1].
        labels (ndarray): [total_steps] all ones.
    """
    t = np.linspace(0, 1, total_steps)
    radius = np.clip(a - b * t, 1e-3, None)  # the floor keeps the radius positive
    angle = omega * t

    x = radius * np.cos(angle)
    y = radius * np.sin(angle)

    if noise_std > 0:
        x = x + np.random.randn(total_steps) * noise_std
        y = y + np.random.randn(total_steps) * noise_std

    data = np.stack((x, y), axis=1)
    labels = np.ones(total_steps)

    if not visualize:
        return data, t, labels

    # Figure 1: the spiral in the x-y plane.
    plt.figure(figsize=(6, 6))
    plt.plot(x, y, lw=1.5, label='Inward Spiral')
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title("Inward Spiral (Rotation + Scaling)")
    plt.axis("equal")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Figure 2: both coordinates as functions of time.
    plt.figure(figsize=(10, 4))
    plt.plot(t, x, label='x(t)', alpha=0.9)
    plt.plot(t, y, label='y(t)', alpha=0.9)
    plt.xlabel("Time t")
    plt.ylabel("Value")
    plt.title("Inward Spiral Components over Time")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

    return data, t, labels

def generate_rotational_symmetric_data_with_region_labels_3d(
    total_steps=10000,
    rotation_axis=np.array([1, 1, 1]),
    rotation_angle=np.pi / 100,
    noise_std=0.0,
    num_samples=None,
    visualize=True
):
    """
    Generate a 3D trajectory by repeatedly rotating a unit vector about an axis.

    Starting from the normalised point (1.0, 0.5, 0.2), the same rotation
    matrix is applied once per step. With positive `noise_std`, Gaussian noise
    is added after each rotation and the state is renormalised.

    Args:
        total_steps: number of rotation steps to generate.
        rotation_axis: 3-vector giving the rotation axis (normalised internally).
        rotation_angle: rotation applied per step, in radians.
        noise_std: std of the per-step Gaussian noise (NumPy global RNG).
        num_samples: if set and smaller than `total_steps`, keep a random
            sorted subset of this many steps.
        visualize: if True, show a 3D plot and a time-series plot.

    Returns:
        data: [n, 3] states, n = num_samples when subsampling else total_steps.
        time: [n] integer step indices.
        labels: [n] integer ones.
    """
    def rotation_matrix(axis, theta):
        # Quaternion-style parameters of the rotation; the arithmetic order is
        # kept as-is so results stay bit-identical.
        unit = axis / np.linalg.norm(axis)
        qw = np.cos(theta / 2)
        qx, qy, qz = -unit * np.sin(theta / 2)
        row0 = [qw*qw + qx*qx - qy*qy - qz*qz, 2*(qx*qy - qw*qz), 2*(qx*qz + qw*qy)]
        row1 = [2*(qx*qy + qw*qz), qw*qw + qy*qy - qx*qx - qz*qz, 2*(qy*qz - qw*qx)]
        row2 = [2*(qx*qz - qw*qy), 2*(qy*qz + qw*qx), qw*qw + qz*qz - qx*qx - qy*qy]
        return np.array([row0, row1, row2])

    rotation_axis = rotation_axis / np.linalg.norm(rotation_axis)
    R = rotation_matrix(rotation_axis, rotation_angle)

    # Unit-length starting point.
    point = np.array([1.0, 0.5, 0.2])
    point /= np.linalg.norm(point)

    add_noise = noise_std > 0
    data = np.zeros((total_steps, 3))
    for step in range(total_steps):
        point = R @ point
        if add_noise:
            point += np.random.randn(3) * noise_std
            point /= np.linalg.norm(point)
        data[step] = point

    time = np.arange(total_steps)
    labels = np.ones_like(time)

    if num_samples is not None and num_samples < total_steps:
        keep = np.sort(np.random.choice(total_steps, size=num_samples, replace=False))
        data, time, labels = data[keep], time[keep], labels[keep]

    if visualize:
        # 3D trajectory.
        fig = plt.figure(figsize=(6, 6))
        ax = fig.add_subplot(111, projection='3d')
        ax.plot(data[:, 0], data[:, 1], data[:, 2], lw=0.8, c='blue')
        ax.set_title('3D Rotational Motion on Sphere')
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_box_aspect([1, 1, 1])
        plt.tight_layout()
        plt.show()

        # Coordinates over time.
        plt.figure(figsize=(10, 4))
        for column, name in enumerate(('x(t)', 'y(t)', 'z(t)')):
            plt.plot(time, data[:, column], label=name, alpha=0.8)
        plt.title('Time Series of 3D Spherical Rotation')
        plt.xlabel('Time')
        plt.ylabel('Value')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    return data, time, labels


def generate_PV_dataset(n_trajectories=135, total_steps=31 * 4):
    """
    Build a PV power dataset from the CSV files in the "onemin-Ground" folder.

    For every file in the folder (in `os.listdir` order), rows 460-999 of the
    four `ShuntPDC_kW_Avg_*` columns are kept and subsampled every 4th row,
    giving up to 135 samples per column. Each column of each file becomes one
    trajectory. All values are min-max scaled together to [-1, 1] and only the
    first 100 samples of each trajectory are returned.

    Args:
        n_trajectories: unused.
        total_steps: length of the `linspace(0, 1, total_steps)` grid from
            which the first 100 time values are taken.

    Returns:
        dataset: ndarray [4 * number_of_files, up to 100, 1].
        time: the first 100 entries of `np.linspace(0, 1, total_steps)`.
    """
    import os

    folder = "onemin-Ground"
    cols = ["ShuntPDC_kW_Avg_4",
            "ShuntPDC_kW_Avg_5",
            "ShuntPDC_kW_Avg_6",
            "ShuntPDC_kW_Avg_7"]

    traj_list = []
    for fname in os.listdir(folder):
        frame = pd.read_csv(os.path.join(folder, fname), usecols=cols)
        # Rows 460..999, every 4th row.
        frame = frame.iloc[460:1000:4].reset_index(drop=True)
        traj_list.extend(frame[name].values for name in cols)

    # One row per (file, column) pair, plus a trailing feature axis.
    dataset = np.stack(traj_list, axis=0)[..., np.newaxis]

    # Global min-max scaling to [-1, 1].
    low = dataset.min()
    high = dataset.max()
    dataset = 2 * (dataset - low) / (high - low) - 1

    time = np.linspace(0, 1, total_steps)
    return dataset[:, :100, :], time[:100]


import numpy as np
import matplotlib.pyplot as plt

import numpy as np
import matplotlib.pyplot as plt


def generate_spherical_rotation_with_fixed_R(
        total_steps=10000,
        num_samples=None,
        visualize=True
):
    """
    Trace a trajectory on the unit sphere by repeatedly applying one fixed rotation.

    Starting from (1, 0, 0), the same non-diagonal SO(3) matrix is applied at
    every step and the state is re-normalized before it is stored.

    Args:
        total_steps (int): Number of steps in the full trajectory.
        num_samples (int or None): If given and smaller than ``total_steps``,
            keep only this many randomly chosen steps (in chronological order).
        visualize (bool): Show a 3D trajectory plot and a time-series plot.

    Returns:
        data (ndarray): [N, 3] points on the unit sphere.
        time (ndarray): [N] step indices of the returned points.
        labels (ndarray): [N] array of ones.
    """
    # Fixed rotation: orthogonal, determinant +1, not diagonal
    R = np.array([
        [0.36, 0.48, -0.8],
        [-0.8, 0.60, 0.0],
        [0.48, 0.64, 0.6]
    ])

    # Sanity-check that R really is a member of SO(3)
    assert np.allclose(R.T @ R, np.eye(3), atol=1e-6), "R is not orthogonal!"
    assert np.isclose(np.linalg.det(R), 1.0, atol=1e-6), "R is not in SO(3)!"

    data = np.zeros((total_steps, 3))
    point = np.array([1.0, 0.0, 0.0])
    for step in range(total_steps):
        # Rotate, then project back onto the sphere to suppress round-off drift
        point = R @ point
        point = point / np.linalg.norm(point)
        data[step] = point

    time = np.arange(total_steps)
    labels = np.ones_like(time)

    # Optional random subsampling; sorted so the kept points stay in time order
    if num_samples is not None and num_samples < total_steps:
        keep = np.sort(np.random.choice(total_steps, size=num_samples, replace=False))
        data = data[keep]
        time = time[keep]
        labels = labels[keep]

    if visualize:
        # 3D trajectory
        fig = plt.figure(figsize=(6, 6))
        ax = fig.add_subplot(111, projection='3d')
        ax.plot(data[:, 0], data[:, 1], data[:, 2], lw=0.8)
        ax.set_title("3D Rotation on Unit Sphere (Fixed R)")
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")
        ax.set_box_aspect([1, 1, 1])
        plt.tight_layout()
        plt.show()

        # One curve per coordinate over time
        plt.figure(figsize=(10, 4))
        for component, curve_label in enumerate(('x(t)', 'y(t)', 'z(t)')):
            plt.plot(time, data[:, component], label=curve_label)
        plt.title("Time Series of Latent Rotation on Sphere")
        plt.xlabel("Time step")
        plt.ylabel("Value")
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.show()

    return data, time, labels


def generate_rotational_symmetric_data(time_steps=1000, rotation_angle=np.pi / 50, noise=0.01):
    """
    Generate synthetic 2D data with rotational symmetry.

    A random initial state is rotated by ``rotation_angle`` at every step; the
    stored sample is the rotated state plus Gaussian noise.  The noise is added
    to the stored sample only, it is not fed back into the state.

    Parameters:
        time_steps (int): Number of time steps for the data.
        rotation_angle (float): The angle of rotation in radians at each time step.
        noise (float): Standard deviation of Gaussian noise added to the data.

    Returns:
        data (ndarray): Generated 2D data of shape (time_steps, 2).  Only the
            data array is returned; the time index of row ``i`` is simply ``i``.
    """
    data = np.zeros((time_steps, 2))

    # Random starting point drawn from a standard normal distribution
    state = np.random.randn(2)

    # 2D counter-clockwise rotation matrix
    cos_a, sin_a = np.cos(rotation_angle), np.sin(rotation_angle)
    rotation_matrix = np.array([
        [cos_a, -sin_a],
        [sin_a, cos_a]
    ])

    for t in range(time_steps):
        # Advance the noise-free state, then store a noisy observation of it
        state = np.dot(rotation_matrix, state)
        data[t] = state + noise * np.random.randn(2)

    return data

def generate_rotational_symmetric_data_with_region_labels(
    time_steps=1000,
    rotation_angle=np.pi / 50,
    noise=0.01,
    unsafe_angle_range=(0, np.pi / 4)
):
    """
    Generate synthetic 2D data with rotational symmetry and define safe/unsafe states based on regions.

    A random initial state is rotated by ``rotation_angle`` at every step and
    stored with additive Gaussian noise.  Each stored point is then labelled by
    its polar angle, measured in [0, 2π).

    Parameters:
        time_steps (int): Number of time steps for the data.
        rotation_angle (float): The angle of rotation in radians at each time step.
        noise (float): Standard deviation of Gaussian noise added to the data.
        unsafe_angle_range (tuple): ``(unsafe_min, unsafe_max)`` in radians, both
            expected in [0, 2π].  If ``unsafe_min > unsafe_max`` the range is
            interpreted as wrapping through 0 rad, i.e. the unsafe region is
            ``angle >= unsafe_min or angle <= unsafe_max``.

    Returns:
        data (ndarray): Generated 2D data of shape (time_steps, 2).
        time (ndarray): Array of time indices of shape (time_steps,).
        labels (ndarray): Array of 0 (safe) or 1 (unsafe) of shape (time_steps,).
    """
    data = np.zeros((time_steps, 2))

    # Random starting point
    state = np.random.randn(2)

    # 2D counter-clockwise rotation matrix
    rotation_matrix = np.array([
        [np.cos(rotation_angle), -np.sin(rotation_angle)],
        [np.sin(rotation_angle), np.cos(rotation_angle)]
    ])

    for t in range(time_steps):
        # Advance the noise-free state, then store a noisy observation of it
        state = np.dot(rotation_matrix, state)
        data[t] = state + noise * np.random.randn(2)

    time = np.arange(time_steps)

    # arctan2 returns angles in [-π, π]; shift them into [0, 2π)
    angles = np.arctan2(data[:, 1], data[:, 0])
    angles = (angles + 2 * np.pi) % (2 * np.pi)

    unsafe_min, unsafe_max = unsafe_angle_range
    if unsafe_min <= unsafe_max:
        unsafe = (angles >= unsafe_min) & (angles <= unsafe_max)
    else:
        # Range crosses the 0 / 2π seam: a plain interval test would never
        # match, so take the union of the two pieces on either side of it.
        unsafe = (angles >= unsafe_min) | (angles <= unsafe_max)
    labels = np.where(unsafe, 1, 0)

    return data, time, labels

def generate_time_series(length, noise=0.01):
    """
    Create a noisy (sin, cos) pair sampled on ``length`` points of [0, 10].

    Returns:
        time (ndarray): [length] sample positions from ``np.linspace(0, 10, length)``.
        data (ndarray): [length, 2] with column 0 = sin(time) + noise and
            column 1 = cos(time) + noise (independent Gaussian noise per column).
    """
    time = np.linspace(0, 10, length)
    # Noise is drawn for the sine channel first, then for the cosine channel
    sine_channel = np.sin(time) + noise * np.random.randn(length)
    cosine_channel = np.cos(time) + noise * np.random.randn(length)
    return time, np.column_stack((sine_channel, cosine_channel))

def gene_data():
    """
    Generate a 1000-step 2D trajectory from the linear map ``h <- R h + 0.02 h``.

    ``R`` is a 2D rotation by 0.05 rad and the initial state is drawn uniformly
    from [0, 1)^2, so every step rotates the state and adds 2% of it back.

    Returns:
        data (ndarray): [1000, 2] array whose row ``t`` is the state after
            ``t + 1`` updates.
    """
    theta = 0.05
    n_steps = 1000
    R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])

    h = np.random.rand(2, 1)  # column vector
    data = np.zeros((n_steps, 2))

    for t in range(n_steps):
        h = R @ h + 0.02 * h
        # Copy the whole column at once; assigning the shape-(1,) slices h[0]
        # and h[1] into scalar slots relies on an implicit array-to-scalar
        # conversion that NumPy has deprecated.
        data[t] = h[:, 0]

    return data

def create_sequences(data, input_length, output_length):
    """
    Slide a window over ``data`` and split each window into (input, output).

    Window ``k`` covers ``data[k : k + input_length + output_length]``; its
    first ``input_length`` items form the input and the remaining
    ``output_length`` items form the output.  An empty list is returned when
    ``data`` is shorter than one full window.

    Returns:
        list of (input_seq, output_seq) tuples, one per window position.
    """
    window = input_length + output_length
    return [
        (data[start:start + input_length], data[start + input_length:start + window])
        for start in range(len(data) - window + 1)
    ]

def fourier_encode(t, num_bands=4, max_freq=5.0):
    """
    Encode time values with sine/cosine features at several frequencies.

    The frequencies are ``num_bands`` evenly spaced values from 1 to
    ``max_freq``.  (Previously every band was 0, which made the encoding a
    constant vector of zeros and ones and ignored ``max_freq``.)

    Args:
        t (Tensor): Time values; the features are concatenated along the last
            dimension, so a ``[B, 1]`` input yields ``[B, 2 * num_bands]``.
        num_bands (int): Number of frequency bands.
        max_freq (float): Highest frequency, in cycles per unit of ``t``.

    Returns:
        Tensor: All sine features followed by all cosine features, joined on
        the last dimension.
    """
    freq_bands = torch.linspace(1.0, max_freq, steps=num_bands, device=t.device)  # [num_bands]
    angle = t * 2 * math.pi  # Convert to radians
    encoded = [torch.sin(freq * angle) for freq in freq_bands] + \
              [torch.cos(freq * angle) for freq in freq_bands]
    return torch.cat(encoded, dim=-1)

def normalize_data(data):
    """
    Fit a fresh ``MinMaxScaler`` on ``data`` and return the scaled data with it.

    Returns:
        (scaled, scaler): the transformed data and the fitted scaler, which the
        caller can keep to invert or re-apply the scaling.
    """
    scaler = MinMaxScaler()
    scaled = scaler.fit_transform(data)
    return scaled, scaler

def get_load_oncor():
    """
    Load kWh profiles from a hard-coded CSV and return two smoothed series.

    Steps, in order:
      1. Read the CSV at the hard-coded Windows path and keep its first 200 rows.
      2. Parse the string-encoded "kWh" and "time" cells with ``ast.literal_eval``.
      3. Keep only entries whose length equals the longest kWh entry, then drop
         rows that contain more than one zero.
      4. Write the surviving kWh rows to ``Load_profile1.csv`` in the current
         working directory (side effect).
      5. Transpose, apply ``gaussian_filter`` with sigma 1, print diagnostics,
         and return columns 45:47.

    Returns:
        ndarray: ``y[:, 45:47]`` of the filtered, transposed data, i.e. up to
        two of the retained profiles, one per column.
    """
    import ast
    path = 'C:\\Simulation\\Geometric AI\\Latent_Manifold\\Data\\CNEXP0006_kWh_two_week_data.csv'
    df_load = pd.read_csv(path)
    load = df_load.iloc[0:200, :]

    # Each cell holds a Python literal stored as text; parse it into an object
    raw_data = [x for x in load["kWh"].apply(ast.literal_eval)]
    raw_time = [x for x in load["time"].apply(ast.literal_eval)]
    # print(time)
    # exit()

    # Find the maximum length of lists
    max_length = max(len(lst) for lst in raw_data)

    # Filter the lists that have the maximum length
    # NOTE(review): the time lists are filtered by their own length, not by the
    # length of the matching kWh list — presumably both always agree; confirm.
    filtered_data = [lst for lst in raw_data if len(lst) == max_length]
    filtered_time = [lst for lst in raw_time if len(lst) == max_length]

    filtered_data = np.array(filtered_data)
    filtered_time = np.array(filtered_time)
    # eliminate rows with many zeros
    # Define a threshold for the maximum number of 0s allowed in a row
    threshold = 1  # This means we allow up to 1 zero in a row
    # Count the number of 0s in each row
    zero_counts = np.count_nonzero(filtered_data == 0, axis=1)

    # Filter out rows that have more than 'threshold' number of 0s
    # (the same boolean mask is applied to both arrays, so they must have the
    # same number of rows at this point)
    data = filtered_data[zero_counts <= threshold]
    time = filtered_time[zero_counts <= threshold]

    # Save the array to a CSV file
    np.savetxt("Load_profile1.csv", data, delimiter=",", fmt="%s")
    # Only the length of the first time row is used: the timestamps are
    # replaced by the integer indices 0, 1, ..., len - 1 on the next line.
    time = time[0, :].reshape(-1, )
    time = np.array([1 * i for i in range(len(time))])

    # Transpose so that rows are time samples and columns are profiles
    data = torch.tensor(data, dtype=torch.float32).T

    # data = data[500: 1000]
    # time = time[500: 1000]
    # Full-range slices: these two lines leave data and time unchanged
    data = data[::, :]
    time = time[::, ]

    # NOTE(review): a scalar sigma smooths along both axes, so neighbouring
    # profiles are blended as well as neighbouring time samples — confirm
    # that this is intended.
    data = gaussian_filter(data, 1)

    # data = np.tile(data, (1, 1, 1))
    # data = normalize_2d(data)

    # data = gaussian_filter(data, 5)  #### todo gaussian filter

    t = torch.tensor(time, dtype=torch.float)
    y = torch.tensor(data, dtype=torch.float)

    # Choose 20 evenly spaced indices over the full time axis
    L_high, L_low = len(t), 20 #######
    t_low_index = np.linspace(0, L_high - 1, L_low, dtype=int)

    t_high = t.numpy()
    t_high = t_high - t_high[0]  # shift so the time axis starts at 0
    t_low = t_high[t_low_index]

    # Columns 45:47 are what the function returns; columns 20:22 are used only
    # for the subsampled series and the printed shapes below.
    y_high_split = y[:, 45:47].numpy() #16:17
    y_low_high_res = y[:, 20:22].numpy()  #20:21


    y_low_split = y_low_high_res[t_low_index, :]
    LR_interval = L_high // L_low

    # Diagnostics only: none of the values below affect the return value
    print("Low rate: {}/min".format(1/(t_low[1] - t_low[0])))
    print("High rate: {}/min".format(1 / (t_high[1] - t_high[0])))

    print("t_high[0] = ", t_high[0])
    print("t_high[-1] = ", t_high[-1])
    print("t_high[1] - t_high[0] = {} min".format(t_high[1] - t_high[0]))
    print("t_low[1] - t_low[0] = {} min".format(t_low[1] - t_low[0]))

    print(t_high.shape, y_high_split.shape, t_low.shape, y_low_split.shape, y_low_high_res.shape, t_low_index.shape)
    print("LR_interval", LR_interval)



    return  y_high_split

def set_seed(seed):
    """
    Seed every random number generator used in this file and make cuDNN deterministic.

    Seeds torch (CPU, current CUDA device and all CUDA devices), NumPy's global
    generator and Python's ``random`` module with the same value, then turns
    off cuDNN benchmarking and requests deterministic cuDNN kernels.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Reproducibility switches for cuDNN
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def generate_glycolytic_dataset(n_trajectories=100, total_steps=200, noise=0.0):
    """
    Simulate Sel'kov glycolytic-oscillator trajectories and scale them to [-1, 1].

    Each trajectory starts from a random state in [0.8, 1.2]^2, is integrated
    on 1001 points of t in [0, 100] and is then down-sampled to ``total_steps``
    evenly spaced points.  The whole dataset is min-max scaled with its global
    minimum and maximum.

    Args:
        n_trajectories (int): Number of trajectories to simulate.
        total_steps (int): Number of samples kept per trajectory.
        noise (float): Accepted for interface compatibility; not applied.

    Returns:
        dataset (ndarray): [n_trajectories, total_steps, 2] scaled states.
        time (ndarray): [total_steps] normalized time grid on [0, 1].
    """
    from scipy.integrate import odeint as sp_odeint

    # Sel'kov oscillator parameters
    a = 0.75  # Input flux
    b = 0.1  # Output flux

    def selkov_oscillator(z, t):
        x1, x2 = z
        dx1dt = a - b * x1 - x1 * (x2 ** 2)
        dx2dt = b * x1 - x2 + x1 * (x2 ** 2)
        return [dx1dt, dx2dt]

    dataset = np.zeros((n_trajectories, total_steps, 2))
    t_full = np.linspace(0, 100, 1001)
    # Down-sampling indices must match the buffer length; a hard-coded 200
    # here broke every call with total_steps != 200.
    idx = np.linspace(0, len(t_full) - 1, total_steps, dtype=int)

    for i in range(n_trajectories):
        # random initialization
        z0 = np.random.uniform(0.8, 1.2, size=2)
        sol = sp_odeint(selkov_oscillator, z0, t_full)
        dataset[i] = sol[idx]

    # Global min-max normalization to [-1, 1]
    min_val = dataset.min()
    max_val = dataset.max()
    dataset = 2 * (dataset - min_val) / (max_val - min_val) - 1
    return dataset, np.linspace(0, 1, total_steps)


def generate_lotka_dataset(n_trajectories=100, total_steps=200, noise=0.0):
    """
    Simulate Lotka-Volterra predator-prey trajectories and scale them to [-1, 1].

    Each trajectory starts from prey in [35, 45] and predators in [1.75, 2.25],
    is integrated on 1001 points of t in [0, 100] and is then down-sampled to
    ``total_steps`` evenly spaced points.  The whole dataset is min-max scaled
    with its global minimum and maximum.

    Args:
        n_trajectories (int): Number of trajectories to simulate.
        total_steps (int): Number of samples kept per trajectory.
        noise (float): Accepted for interface compatibility; not applied.

    Returns:
        dataset (ndarray): [n_trajectories, total_steps, 2] scaled (prey, predator) states.
        time (ndarray): [total_steps] normalized time grid on [0, 1].
    """
    from scipy.integrate import odeint as sp_odeint

    # Lotka-Volterra parameters
    alpha = 0.1  # prey growth rate
    beta = 0.02  # rate of predation
    gamma = 0.3  # predator death rate
    delta = 0.01  # predator growth from eating prey

    def lotka_volterra_equations(z, t_):
        x, y = z
        dxdt = alpha * x - beta * x * y
        dydt = delta * x * y - gamma * y
        return [dxdt, dydt]

    dataset = np.zeros((n_trajectories, total_steps, 2))
    t_full = np.linspace(0, 100, 1001)
    # Down-sampling indices must match the buffer length; a hard-coded 200
    # here broke every call with total_steps != 200.
    idx = np.linspace(0, len(t_full) - 1, total_steps, dtype=int)

    for i in range(n_trajectories):
        # random initialization
        x0 = np.random.uniform(40 - 5, 40 + 5)
        y0 = np.random.uniform(2 - 0.25, 2 + 0.25)
        sol = sp_odeint(lotka_volterra_equations, [x0, y0], t_full)
        dataset[i] = sol[idx]

    # Global min-max normalization to [-1, 1]
    min_val = dataset.min()
    max_val = dataset.max()
    dataset = 2 * (dataset - min_val) / (max_val - min_val) - 1
    return dataset, np.linspace(0, 1, total_steps)


def generate_power_event_dataset(n_trajectories=100, total_steps=100):
    """
    Load a window of the ``Event5.csv`` recording as a normalized dataset.

    Only columns whose name ends with ``'VP.A'`` are read.  Rows are time
    samples and columns are trajectories; ``total_steps`` rows starting at
    row 156 are kept, and column 23 is skipped because it contains outliers.
    The result is min-max scaled to [-1, 1] using its global minimum and
    maximum.

    Args:
        n_trajectories (int): Number of columns (trajectories) to use.
        total_steps (int): Number of consecutive rows (time samples) to use.

    Returns:
        dataset (ndarray): [n_trajectories, total_steps, 1] scaled trajectories.
        time (ndarray): [total_steps] normalized time grid on [0, 1].

    Raises:
        ValueError: If the file does not provide the requested number of rows
            or columns.
    """
    first_row = 156
    outlier_column = 23  # 23 has some outliers

    df = pd.read_csv('Event5.csv', usecols=lambda name: name.endswith('VP.A'))
    # Take one spare column so the outlier can be skipped; the final slice
    # keeps the count right when the outlier is beyond the requested range.
    select_traj = [c for c in range(n_trajectories + 1) if c != outlier_column][:n_trajectories]
    df = df.iloc[first_row:first_row + total_steps, select_traj]

    values = df.to_numpy()  # shape (total_steps, n_trajectories)
    if values.shape != (total_steps, n_trajectories):
        raise ValueError(
            "Event5.csv yields a window of shape {}, expected {}".format(
                values.shape, (total_steps, n_trajectories)))

    # Put trajectories on the first axis.  A bare reshape of the
    # (time, trajectory) table would mix samples of different trajectories.
    dataset = values.T[:, :, np.newaxis]

    # Global min-max normalization to [-1, 1]
    min_val = dataset.min()
    max_val = dataset.max()
    dataset = 2 * (dataset - min_val) / (max_val - min_val) - 1
    return dataset, np.linspace(0, 1, total_steps)